<!DOCTYPE html>

<html lang="en">

<head>

<meta charset="utf-8" />
<meta name="generator" content="pandoc" />
<meta http-equiv="X-UA-Compatible" content="IE=EDGE" />

<meta name="viewport" content="width=device-width, initial-scale=1" />



<title>Loading data</title>

<script src="data:application/javascript;base64,Ly8gUGFuZG9jIDIuOSBhZGRzIGF0dHJpYnV0ZXMgb24gYm90aCBoZWFkZXIgYW5kIGRpdi4gV2UgcmVtb3ZlIHRoZSBmb3JtZXIgKHRvCi8vIGJlIGNvbXBhdGlibGUgd2l0aCB0aGUgYmVoYXZpb3Igb2YgUGFuZG9jIDwgMi44KS4KZG9jdW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignRE9NQ29udGVudExvYWRlZCcsIGZ1bmN0aW9uKGUpIHsKICB2YXIgaHMgPSBkb2N1bWVudC5xdWVyeVNlbGVjdG9yQWxsKCJkaXYuc2VjdGlvbltjbGFzcyo9J2xldmVsJ10gPiA6Zmlyc3QtY2hpbGQiKTsKICB2YXIgaSwgaCwgYTsKICBmb3IgKGkgPSAwOyBpIDwgaHMubGVuZ3RoOyBpKyspIHsKICAgIGggPSBoc1tpXTsKICAgIGlmICghL15oWzEtNl0kL2kudGVzdChoLnRhZ05hbWUpKSBjb250aW51ZTsgIC8vIGl0IHNob3VsZCBiZSBhIGhlYWRlciBoMS1oNgogICAgYSA9IGguYXR0cmlidXRlczsKICAgIHdoaWxlIChhLmxlbmd0aCA+IDApIGgucmVtb3ZlQXR0cmlidXRlKGFbMF0ubmFtZSk7CiAgfQp9KTsK"></script>

<style type="text/css">
  /* code elements: keep whitespace as written, but allow lines to wrap. */
  code {
    white-space: pre-wrap;
  }
  /* Inline span helpers for small caps and underlined text. */
  span.smallcaps {
    font-variant: small-caps;
  }
  span.underline {
    text-decoration: underline;
  }
  /* Columns: inline blocks aligned at the top, each half the container width. */
  div.column {
    display: inline-block;
    vertical-align: top;
    width: 50%;
  }
  /* Hanging indent: block shifted right by 1.5em, first line pulled back by 1.5em. */
  div.hanging-indent {
    margin-left: 1.5em;
    text-indent: -1.5em;
  }
  /* Task lists are rendered without list markers. */
  ul.task-list {
    list-style: none;
  }
    </style>


<style type="text/css">
  /* code elements: preserve whitespace exactly, no wrapping. */
  code { white-space: pre; }
  /* Anything carrying the sourceCode class lets overflowing content show. */
  .sourceCode { overflow: visible; }
</style>
<style type="text/css" data-origin="pandoc">
/* Layout and token colors for highlighted code listings (elements carrying
   the sourceCode class). Rule order matters here: later rules override
   earlier ones of equal specificity. */

/* Code lines: every direct child span of code.sourceCode is laid out as an
   inline block with a fixed line height. */
pre > code.sourceCode { white-space: pre; position: relative; }
pre > code.sourceCode > span { display: inline-block; line-height: 1.25; }
/* Empty line spans get an explicit height (presumably so blank code lines
   keep their vertical space). */
pre > code.sourceCode > span:empty { height: 1.2em; }
.sourceCode { overflow: visible; }
/* Line spans take color and text decoration from their parent. */
code.sourceCode > span { color: inherit; text-decoration: inherit; }
/* Vertical spacing is carried by the wrapping div, not by the pre inside it. */
div.sourceCode { margin: 1em 0; }
pre.sourceCode { margin: 0; }
/* On screen, the wrapping div scrolls when its content overflows. */
@media screen {
div.sourceCode { overflow: auto; }
}
/* In print, code may wrap; the negative text-indent paired with an equal
   padding-left gives wrapped continuation lines a hanging indent. */
@media print {
pre > code.sourceCode { white-space: pre-wrap; }
pre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }
}
/* Line numbering for pre.numberSource: the source-line counter is reset on
   each code element and incremented once per line span. */
pre.numberSource code
  { counter-reset: source-line 0; }
pre.numberSource code > span
  { position: relative; left: -4em; counter-increment: source-line; }
/* The number itself is generated content on the first anchor of each line
   span; it is right-aligned, grey, and excluded from text selection. */
pre.numberSource code > span > a:first-child::before
  { content: counter(source-line);
    position: relative; left: -1em; text-align: right; vertical-align: baseline;
    border: none; display: inline-block;
    -webkit-touch-callout: none; -webkit-user-select: none;
    -khtml-user-select: none; -moz-user-select: none;
    -ms-user-select: none; user-select: none;
    padding: 0 4px; width: 4em;
    color: #aaaaaa;
  }
/* Numbered listings get a left margin and a thin rule separating numbers from code. */
pre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }
/* Empty rule. The script following this style element scans this sheet for
   top-level rules whose selector is exactly div.sourceCode and, only when
   such a rule sets color or background-color, re-creates it as
   pre.sourceCode. As written here the rule sets neither. */
div.sourceCode
  {   }
/* On screen, the generated ::before content of each line's first anchor is underlined. */
@media screen {
pre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }
}
/* Token styles: the trailing comment on each rule names the token category
   that the two-letter class stands for. Rules with an empty body leave the
   token unstyled. */
code span.al { color: #ff0000; font-weight: bold; } /* Alert */
code span.an { color: #60a0b0; font-weight: bold; font-style: italic; } /* Annotation */
code span.at { color: #7d9029; } /* Attribute */
code span.bn { color: #40a070; } /* BaseN */
code span.bu { } /* BuiltIn */
code span.cf { color: #007020; font-weight: bold; } /* ControlFlow */
code span.ch { color: #4070a0; } /* Char */
code span.cn { color: #880000; } /* Constant */
code span.co { color: #60a0b0; font-style: italic; } /* Comment */
code span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } /* CommentVar */
code span.do { color: #ba2121; font-style: italic; } /* Documentation */
code span.dt { color: #902000; } /* DataType */
code span.dv { color: #40a070; } /* DecVal */
code span.er { color: #ff0000; font-weight: bold; } /* Error */
code span.ex { } /* Extension */
code span.fl { color: #40a070; } /* Float */
code span.fu { color: #06287e; } /* Function */
code span.im { } /* Import */
code span.in { color: #60a0b0; font-weight: bold; font-style: italic; } /* Information */
code span.kw { color: #007020; font-weight: bold; } /* Keyword */
code span.op { color: #666666; } /* Operator */
code span.ot { color: #007020; } /* Other */
code span.pp { color: #bc7a00; } /* Preprocessor */
code span.sc { color: #4070a0; } /* SpecialChar */
code span.ss { color: #bb6688; } /* SpecialString */
code span.st { color: #4070a0; } /* String */
code span.va { color: #19177c; } /* Variable */
code span.vs { color: #4070a0; } /* VerbatimString */
code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warning */

</style>
<script>
// apply pandoc div.sourceCode style to pre.sourceCode instead
(function() {
  // Walk every stylesheet attached to the document; only sheets whose owner
  // element carries data-origin="pandoc" are candidates for rewriting.
  var allSheets = document.styleSheets;
  for (var s = 0; s < allSheets.length; s++) {
    var sheet = allSheets[s];
    if (sheet.ownerNode.dataset["origin"] !== "pandoc") continue;

    // If reading cssRules throws, leave that sheet alone.
    var ruleList;
    try { ruleList = sheet.cssRules; } catch (e) { continue; }

    for (var r = 0; r < ruleList.length; r++) {
      var current = ruleList[r];

      // Only a top-level style rule whose selector is exactly
      // "div.sourceCode" is of interest.
      var isDivSourceCode = current.type === current.STYLE_RULE &&
        current.selectorText === "div.sourceCode";
      if (!isDivSourceCode) continue;

      // Rules that set neither color nor background-color are kept as they are.
      var setsColor = current.style.color !== '' ||
        current.style.backgroundColor !== '';
      if (!setsColor) continue;

      // Swap the rule in place: same declarations, same index, but
      // targeting pre.sourceCode instead of div.sourceCode.
      var declarations = current.style.cssText;
      sheet.deleteRule(r);
      sheet.insertRule('pre.sourceCode{' + declarations + '}', r);
    }
  }
})();
</script>




<link rel="stylesheet" href="data:text/css,body%20%7B%0Abackground%2Dcolor%3A%20%23fff%3B%0Amargin%3A%201em%20auto%3B%0Amax%2Dwidth%3A%20700px%3B%0Aoverflow%3A%20visible%3B%0Apadding%2Dleft%3A%202em%3B%0Apadding%2Dright%3A%202em%3B%0Afont%2Dfamily%3A%20%22Open%20Sans%22%2C%20%22Helvetica%20Neue%22%2C%20Helvetica%2C%20Arial%2C%20sans%2Dserif%3B%0Afont%2Dsize%3A%2014px%3B%0Aline%2Dheight%3A%201%2E35%3B%0A%7D%0A%23TOC%20%7B%0Aclear%3A%20both%3B%0Amargin%3A%200%200%2010px%2010px%3B%0Apadding%3A%204px%3B%0Awidth%3A%20400px%3B%0Aborder%3A%201px%20solid%20%23CCCCCC%3B%0Aborder%2Dradius%3A%205px%3B%0Abackground%2Dcolor%3A%20%23f6f6f6%3B%0Afont%2Dsize%3A%2013px%3B%0Aline%2Dheight%3A%201%2E3%3B%0A%7D%0A%23TOC%20%2Etoctitle%20%7B%0Afont%2Dweight%3A%20bold%3B%0Afont%2Dsize%3A%2015px%3B%0Amargin%2Dleft%3A%205px%3B%0A%7D%0A%23TOC%20ul%20%7B%0Apadding%2Dleft%3A%2040px%3B%0Amargin%2Dleft%3A%20%2D1%2E5em%3B%0Amargin%2Dtop%3A%205px%3B%0Amargin%2Dbottom%3A%205px%3B%0A%7D%0A%23TOC%20ul%20ul%20%7B%0Amargin%2Dleft%3A%20%2D2em%3B%0A%7D%0A%23TOC%20li%20%7B%0Aline%2Dheight%3A%2016px%3B%0A%7D%0Atable%20%7B%0Amargin%3A%201em%20auto%3B%0Aborder%2Dwidth%3A%201px%3B%0Aborder%2Dcolor%3A%20%23DDDDDD%3B%0Aborder%2Dstyle%3A%20outset%3B%0Aborder%2Dcollapse%3A%20collapse%3B%0A%7D%0Atable%20th%20%7B%0Aborder%2Dwidth%3A%202px%3B%0Apadding%3A%205px%3B%0Aborder%2Dstyle%3A%20inset%3B%0A%7D%0Atable%20td%20%7B%0Aborder%2Dwidth%3A%201px%3B%0Aborder%2Dstyle%3A%20inset%3B%0Aline%2Dheight%3A%2018px%3B%0Apadding%3A%205px%205px%3B%0A%7D%0Atable%2C%20table%20th%2C%20table%20td%20%7B%0Aborder%2Dleft%2Dstyle%3A%20none%3B%0Aborder%2Dright%2Dstyle%3A%20none%3B%0A%7D%0Atable%20thead%2C%20table%20tr%2Eeven%20%7B%0Abackground%2Dcolor%3A%20%23f7f7f7%3B%0A%7D%0Ap%20%7B%0Amargin%3A%200%2E5em%200%3B%0A%7D%0Ablockquote%20%7B%0Abackground%2Dcolor%3A%20%23f6f6f6%3B%0Apadding%3A%200%2E25em%200%2E75em%3B%0A%7D%0Ahr%20%7B%0Aborder%2Dstyle%3A%20solid%3B%0Aborder%3A%20none%3B%0Aborder%2Dtop%3A%201px%20solid%20%23777%3B%0Amargin%3A%20
28px%200%3B%0A%7D%0Adl%20%7B%0Amargin%2Dleft%3A%200%3B%0A%7D%0Adl%20dd%20%7B%0Amargin%2Dbottom%3A%2013px%3B%0Amargin%2Dleft%3A%2013px%3B%0A%7D%0Adl%20dt%20%7B%0Afont%2Dweight%3A%20bold%3B%0A%7D%0Aul%20%7B%0Amargin%2Dtop%3A%200%3B%0A%7D%0Aul%20li%20%7B%0Alist%2Dstyle%3A%20circle%20outside%3B%0A%7D%0Aul%20ul%20%7B%0Amargin%2Dbottom%3A%200%3B%0A%7D%0Apre%2C%20code%20%7B%0Abackground%2Dcolor%3A%20%23f7f7f7%3B%0Aborder%2Dradius%3A%203px%3B%0Acolor%3A%20%23333%3B%0Awhite%2Dspace%3A%20pre%2Dwrap%3B%20%0A%7D%0Apre%20%7B%0Aborder%2Dradius%3A%203px%3B%0Amargin%3A%205px%200px%2010px%200px%3B%0Apadding%3A%2010px%3B%0A%7D%0Apre%3Anot%28%5Bclass%5D%29%20%7B%0Abackground%2Dcolor%3A%20%23f7f7f7%3B%0A%7D%0Acode%20%7B%0Afont%2Dfamily%3A%20Consolas%2C%20Monaco%2C%20%27Courier%20New%27%2C%20monospace%3B%0Afont%2Dsize%3A%2085%25%3B%0A%7D%0Ap%20%3E%20code%2C%20li%20%3E%20code%20%7B%0Apadding%3A%202px%200px%3B%0A%7D%0Adiv%2Efigure%20%7B%0Atext%2Dalign%3A%20center%3B%0A%7D%0Aimg%20%7B%0Abackground%2Dcolor%3A%20%23FFFFFF%3B%0Apadding%3A%202px%3B%0Aborder%3A%201px%20solid%20%23DDDDDD%3B%0Aborder%2Dradius%3A%203px%3B%0Aborder%3A%201px%20solid%20%23CCCCCC%3B%0Amargin%3A%200%205px%3B%0A%7D%0Ah1%20%7B%0Amargin%2Dtop%3A%200%3B%0Afont%2Dsize%3A%2035px%3B%0Aline%2Dheight%3A%2040px%3B%0A%7D%0Ah2%20%7B%0Aborder%2Dbottom%3A%204px%20solid%20%23f7f7f7%3B%0Apadding%2Dtop%3A%2010px%3B%0Apadding%2Dbottom%3A%202px%3B%0Afont%2Dsize%3A%20145%25%3B%0A%7D%0Ah3%20%7B%0Aborder%2Dbottom%3A%202px%20solid%20%23f7f7f7%3B%0Apadding%2Dtop%3A%2010px%3B%0Afont%2Dsize%3A%20120%25%3B%0A%7D%0Ah4%20%7B%0Aborder%2Dbottom%3A%201px%20solid%20%23f7f7f7%3B%0Amargin%2Dleft%3A%208px%3B%0Afont%2Dsize%3A%20105%25%3B%0A%7D%0Ah5%2C%20h6%20%7B%0Aborder%2Dbottom%3A%201px%20solid%20%23ccc%3B%0Afont%2Dsize%3A%20105%25%3B%0A%7D%0Aa%20%7B%0Acolor%3A%20%230033dd%3B%0Atext%2Ddecoration%3A%20none%3B%0A%7D%0Aa%3Ahover%20%7B%0Acolor%3A%20%236666ff%3B%20%7D%0Aa%3Avisited%20%7B%0Acolor%3A%20%23800080%3B%20%7D%0Aa%3Avisited%3Ahover%20%7B%0Acolor%3A%
20%23BB00BB%3B%20%7D%0Aa%5Bhref%5E%3D%22http%3A%22%5D%20%7B%0Atext%2Ddecoration%3A%20underline%3B%20%7D%0Aa%5Bhref%5E%3D%22https%3A%22%5D%20%7B%0Atext%2Ddecoration%3A%20underline%3B%20%7D%0A%0Acode%20%3E%20span%2Ekw%20%7B%20color%3A%20%23555%3B%20font%2Dweight%3A%20bold%3B%20%7D%20%0Acode%20%3E%20span%2Edt%20%7B%20color%3A%20%23902000%3B%20%7D%20%0Acode%20%3E%20span%2Edv%20%7B%20color%3A%20%2340a070%3B%20%7D%20%0Acode%20%3E%20span%2Ebn%20%7B%20color%3A%20%23d14%3B%20%7D%20%0Acode%20%3E%20span%2Efl%20%7B%20color%3A%20%23d14%3B%20%7D%20%0Acode%20%3E%20span%2Ech%20%7B%20color%3A%20%23d14%3B%20%7D%20%0Acode%20%3E%20span%2Est%20%7B%20color%3A%20%23d14%3B%20%7D%20%0Acode%20%3E%20span%2Eco%20%7B%20color%3A%20%23888888%3B%20font%2Dstyle%3A%20italic%3B%20%7D%20%0Acode%20%3E%20span%2Eot%20%7B%20color%3A%20%23007020%3B%20%7D%20%0Acode%20%3E%20span%2Eal%20%7B%20color%3A%20%23ff0000%3B%20font%2Dweight%3A%20bold%3B%20%7D%20%0Acode%20%3E%20span%2Efu%20%7B%20color%3A%20%23900%3B%20font%2Dweight%3A%20bold%3B%20%7D%20%0Acode%20%3E%20span%2Eer%20%7B%20color%3A%20%23a61717%3B%20background%2Dcolor%3A%20%23e3d2d2%3B%20%7D%20%0A" type="text/css" />




</head>

<body>




<h1 class="title toc-ignore">Loading data</h1>



<div class="sourceCode" id="cb1"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="fu">library</span>(torch)</span></code></pre></div>
<div id="datasets-and-data-loaders" class="section level2">
<h2>Datasets and data loaders</h2>
<p>Central to data ingestion and preprocessing are datasets and data loaders.</p>
<p><code>torch</code> comes equipped with a bag of datasets related to, mostly, image recognition and natural language processing (e.g., <code>mnist_dataset()</code>), which can be iterated over by means of <code>dataloader</code>s:</p>
<div class="sourceCode" id="cb2"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># ...</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>ds <span class="ot">&lt;-</span> <span class="fu">mnist_dataset</span>(</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>  dir, </span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>  <span class="at">download =</span> <span class="cn">TRUE</span>, </span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>  <span class="at">transform =</span> <span class="cf">function</span>(x) {</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    x <span class="ot">&lt;-</span> x<span class="sc">$</span><span class="fu">to</span>(<span class="at">dtype =</span> <span class="fu">torch_float</span>())<span class="sc">/</span><span class="dv">256</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    x[newaxis,..]</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>  }</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>dl <span class="ot">&lt;-</span> <span class="fu">dataloader</span>(ds, <span class="at">batch_size =</span> <span class="dv">32</span>, <span class="at">shuffle =</span> <span class="cn">TRUE</span>)</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (b <span class="cf">in</span> <span class="fu">enumerate</span>(dl)) {</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>  <span class="co"># ...</span></span></code></pre></div>
<p>Cf. <code>vignettes/examples/mnist-cnn.R</code> for a complete example.</p>
<p>What if you want to train on a different dataset? In that case, you subclass <code>Dataset</code>, an abstract container that needs to know how to iterate over the given data. To that end, your subclass needs to implement <code>.getitem()</code>, which says what should be returned when the data loader asks for the next batch.</p>
<p>In <code>.getitem()</code>, you can implement whatever preprocessing you require. Additionally, you should implement <code>.length()</code>, so users can find out how many items there are in the dataset.</p>
<p>While this may sound complicated, it is not at all. The base logic is straightforward – complexity will, naturally, correlate with how involved your preprocessing is. To provide you with a simple but functional prototype, here we show how to create your own <code>dataset</code> to train on <a href="https://github.com/allisonhorst/palmerpenguins">Allison Horst&#39;s penguins</a>.</p>
</div>
<div id="a-custom-dataset" class="section level2">
<h2>A custom dataset</h2>
<div class="sourceCode" id="cb3"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="fu">library</span>(palmerpenguins)</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="fu">library</span>(magrittr)</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>penguins</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; # A tibble: 344 x 8</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;    &lt;fct&gt;   &lt;fct&gt;              &lt;dbl&gt;         &lt;dbl&gt;             &lt;int&gt;       &lt;int&gt;</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1 Adelie  Torgersen           39.1          18.7               181        3750</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  2 Adelie  Torgersen           39.5          17.4               186        3800</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  3 Adelie  Torgersen           40.3          18                 195        3250</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  4 Adelie  Torgersen           NA            NA                  NA          NA</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  5 Adelie  Torgersen           36.7          19.3               193        3450</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  6 Adelie  Torgersen           39.3          20.6               190        3650</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  7 Adelie  Torgersen           38.9          17.8               181        3625</span></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  8 Adelie  Torgersen           39.2          19.6               195        4675</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  9 Adelie  Torgersen           34.1          18.1               193        3475</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; 10 Adelie  Torgersen           42            20.2               190        4250</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; # … with 334 more rows, and 2 more variables: sex &lt;fct&gt;, year &lt;int&gt;</span></span></code></pre></div>
<p><code>Dataset</code>s are R6 classes created using the <code>dataset()</code> constructor. You can pass a <code>name</code> and various member functions. Among those should be <code>initialize()</code>, to create instance variables, <code>.getitem()</code>, to indicate how the data should be returned, and <code>.length()</code>, to say how many items we have.</p>
<p>In addition, any number of helper functions can be defined.</p>
<p>Here, we assume the <code>penguins</code> have already been loaded, and all preprocessing consists in removing rows with <code>NA</code> values, transforming <code>factor</code>s to numbers (their integer codes, which start from 1), and converting from R data types to <code>torch</code> tensors.</p>
<p>In <code>.getitem</code>, we essentially decide how this data is going to be used: All variables besides <code>species</code> go into <code>x</code>, the predictor, and <code>species</code> will constitute <code>y</code>, the target. Predictor and target are returned in a list, to be accessed as <code>batch[[1]]</code> and <code>batch[[2]]</code> during training.</p>
<div class="sourceCode" id="cb4"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a>penguins_dataset <span class="ot">&lt;-</span> <span class="fu">dataset</span>(</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>  <span class="at">name =</span> <span class="st">&quot;penguins_dataset&quot;</span>,</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>  <span class="at">initialize =</span> <span class="cf">function</span>() {</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    self<span class="sc">$</span>data <span class="ot">&lt;-</span> self<span class="sc">$</span><span class="fu">prepare_penguin_data</span>()</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>  },</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>  <span class="at">.getitem =</span> <span class="cf">function</span>(index) {</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    x <span class="ot">&lt;-</span> self<span class="sc">$</span>data[index, <span class="dv">2</span><span class="sc">:-</span><span class="dv">1</span>]</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    y <span class="ot">&lt;-</span> self<span class="sc">$</span>data[index, <span class="dv">1</span>]<span class="sc">$</span><span class="fu">to</span>(<span class="fu">torch_long</span>())</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    <span class="fu">list</span>(x, y)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>  },</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>  <span class="at">.length =</span> <span class="cf">function</span>() {</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    self<span class="sc">$</span>data<span class="sc">$</span><span class="fu">size</span>()[[<span class="dv">1</span>]]</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>  },</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>  <span class="at">prepare_penguin_data =</span> <span class="cf">function</span>() {</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    input <span class="ot">&lt;-</span> <span class="fu">na.omit</span>(penguins) </span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># conveniently, the categorical data are already factors</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    input<span class="sc">$</span>species <span class="ot">&lt;-</span> <span class="fu">as.numeric</span>(input<span class="sc">$</span>species)</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    input<span class="sc">$</span>island <span class="ot">&lt;-</span> <span class="fu">as.numeric</span>(input<span class="sc">$</span>island)</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    input<span class="sc">$</span>sex <span class="ot">&lt;-</span> <span class="fu">as.numeric</span>(input<span class="sc">$</span>sex)</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>    input <span class="ot">&lt;-</span> <span class="fu">as.matrix</span>(input)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    <span class="fu">torch_tensor</span>(input)</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>  }</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div>
<p>Let’s create the dataset, query for its length, and look at its first item:</p>
<div class="sourceCode" id="cb5"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>tuxes <span class="ot">&lt;-</span> <span class="fu">penguins_dataset</span>()</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>tuxes<span class="sc">$</span><span class="fu">.length</span>()</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [1] 333</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>tuxes<span class="sc">$</span><span class="fu">.getitem</span>(<span class="dv">1</span>)</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [[1]]</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; torch_tensor</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;    39.1000</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;    18.7000</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;   181.0000</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  3750.0000</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     2.0000</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  2007.0000</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [ CPUFloatType{7} ]</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; </span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [[2]]</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; torch_tensor</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; 1</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [ CPULongType{} ]</span></span></code></pre></div>
<p>To be able to iterate over <code>tuxes</code>, we need a data loader (we override the default batch size of <code>1</code>):</p>
<div class="sourceCode" id="cb6"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a>dl <span class="ot">&lt;-</span>tuxes <span class="sc">%&gt;%</span> <span class="fu">dataloader</span>(<span class="at">batch_size =</span> <span class="dv">8</span>)</span></code></pre></div>
<p>Calling <code>.length()</code> on a data loader (as opposed to a dataset) will return the number of <code>batches</code> we have:</p>
<div class="sourceCode" id="cb7"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a>dl<span class="sc">$</span><span class="fu">.length</span>()</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [1] 42</span></span></code></pre></div>
<p>And we can create an iterator to inspect the first batch:</p>
<div class="sourceCode" id="cb8"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a>iter <span class="ot">&lt;-</span> dl<span class="sc">$</span><span class="fu">.iter</span>()</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>b <span class="ot">&lt;-</span> iter<span class="sc">$</span><span class="fu">.next</span>()</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>b</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [[1]]</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; torch_tensor</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    39.1000    18.7000   181.0000  3750.0000     2.0000  2007.0000</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    39.5000    17.4000   186.0000  3800.0000     1.0000  2007.0000</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    40.3000    18.0000   195.0000  3250.0000     1.0000  2007.0000</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    36.7000    19.3000   193.0000  3450.0000     1.0000  2007.0000</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    39.3000    20.6000   190.0000  3650.0000     2.0000  2007.0000</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    38.9000    17.8000   181.0000  3625.0000     1.0000  2007.0000</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    39.2000    19.6000   195.0000  4675.0000     2.0000  2007.0000</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;     3.0000    41.1000    17.6000   182.0000  3200.0000     1.0000  2007.0000</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [ CPUFloatType{8,7} ]</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; </span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [[2]]</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; torch_tensor</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt;  1</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; [ CPULongType{8} ]</span></span></code></pre></div>
<p>To train a network, we can use <code>enumerate</code> to iterate over batches.</p>
</div>
<div id="training-with-data-loaders" class="section level2">
<h2>Training with data loaders</h2>
<p>Our example network is very simple. (In reality, we would want to treat <code>island</code> as the categorical variable it is, and either one-hot-encode or embed it.)</p>
<div class="sourceCode" id="cb9"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a>net <span class="ot">&lt;-</span> <span class="fu">nn_module</span>(</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>  <span class="st">&quot;PenguinNet&quot;</span>,</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>  <span class="at">initialize =</span> <span class="cf">function</span>() {</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    self<span class="sc">$</span>fc1 <span class="ot">&lt;-</span> <span class="fu">nn_linear</span>(<span class="dv">7</span>, <span class="dv">32</span>)</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    self<span class="sc">$</span>fc2 <span class="ot">&lt;-</span> <span class="fu">nn_linear</span>(<span class="dv">32</span>, <span class="dv">3</span>)</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>  },</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>  <span class="at">forward =</span> <span class="cf">function</span>(x) {</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    x <span class="sc">%&gt;%</span> </span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>      self<span class="sc">$</span><span class="fu">fc1</span>() <span class="sc">%&gt;%</span> </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>      <span class="fu">nnf_relu</span>() <span class="sc">%&gt;%</span> </span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>      self<span class="sc">$</span><span class="fu">fc2</span>() <span class="sc">%&gt;%</span> </span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>      <span class="fu">nnf_log_softmax</span>(<span class="at">dim =</span> <span class="dv">2</span>)</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>  }</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>model <span class="ot">&lt;-</span> <span class="fu">net</span>()</span></code></pre></div>
<p>We still need an optimizer:</p>
<div class="sourceCode" id="cb10"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a>optimizer <span class="ot">&lt;-</span> <span class="fu">optim_sgd</span>(model<span class="sc">$</span>parameters, <span class="at">lr =</span> <span class="fl">0.01</span>)</span></code></pre></div>
<p>And we’re ready to train:</p>
<div class="sourceCode" id="cb11"><pre class="sourceCode r"><code class="sourceCode r"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (epoch <span class="cf">in</span> <span class="dv">1</span><span class="sc">:</span><span class="dv">10</span>) {</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>  l <span class="ot">&lt;-</span> <span class="fu">c</span>()</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>  <span class="cf">for</span> (b <span class="cf">in</span> <span class="fu">enumerate</span>(dl)) {</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    optimizer<span class="sc">$</span><span class="fu">zero_grad</span>()</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    output <span class="ot">&lt;-</span> <span class="fu">model</span>(b[[<span class="dv">1</span>]])</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    loss <span class="ot">&lt;-</span> <span class="fu">nnf_nll_loss</span>(output, b[[<span class="dv">2</span>]])</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    loss<span class="sc">$</span><span class="fu">backward</span>()</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    optimizer<span class="sc">$</span><span class="fu">step</span>()</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    l <span class="ot">&lt;-</span> <span class="fu">c</span>(l, loss<span class="sc">$</span><span class="fu">item</span>())</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>  }</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>  </span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>  <span class="fu">cat</span>(<span class="fu">sprintf</span>(<span class="st">&quot;Loss at epoch %d: %3f</span><span class="sc">\n</span><span class="st">&quot;</span>, epoch, <span class="fu">mean</span>(l)))</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 1: 45.339662</span></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 2: 2.068251</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 3: 2.068251</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 4: 2.068251</span></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 5: 2.068251</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 6: 2.068251</span></span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 7: 2.068251</span></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 8: 2.068251</span></span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 9: 2.068251</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a><span class="co">#&gt; Loss at epoch 10: 2.068251</span></span></code></pre></div>
</div>



<!-- code folding -->


<!-- dynamically load mathjax for compatibility with self-contained -->
<script>
  // Build the MathJax <script> element at runtime and attach it to <head>,
  // instead of referencing it with a static tag (the template does this for
  // compatibility with self-contained output).
  (function () {
    var MATHJAX_URL = "https://mathjax.rstudio.com/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML";
    var headEl = document.getElementsByTagName("head")[0];
    var loader = document.createElement("script");
    loader.type = "text/javascript";
    loader.src = MATHJAX_URL;
    headEl.appendChild(loader);
  })();
</script>

</body>
</html>
